
Commit c8a2856

feat: support for applying min-snr weighting for faster convergence.
1 parent c73fdba commit c8a2856

File tree: 1 file changed (+58, -5)

examples/text_to_image/train_text_to_image.py
Lines changed: 58 additions & 5 deletions
@@ -51,6 +51,10 @@
 
 logger = get_logger(__name__, log_level="INFO")
 
+DATASET_NAME_MAPPING = {
+    "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -193,6 +197,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
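
As a quick sanity check of the new flag (an illustrative snippet, not part of the diff): with default=None the rebalanced loss is opt-in, and passing a value such as the recommended 5.0 enables it.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--snr_gamma", type=float, default=None)

assert parser.parse_args([]).snr_gamma is None                       # flag omitted -> plain MSE loss
assert parser.parse_args(["--snr_gamma", "5.0"]).snr_gamma == 5.0    # enables Min-SNR weighting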
@@ -325,9 +336,32 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"
 
 
-dataset_name_mapping = {
-    "lambdalabs/pokemon-blip-captions": ("image", "text"),
-}
+def expand_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D tensor for a batch of indices.
+    Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+    """
+    res = arr.to(device=timesteps.device)[timesteps].float()
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def compute_snr(noise_scheduler):
+    """
+    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    def fn(timesteps):
+        alpha = expand_tensor(sqrt_alphas_cumprod, timesteps, timesteps.shape)
+        sigma = expand_tensor(sqrt_one_minus_alphas_cumprod, timesteps, timesteps.shape)
+        snr = (alpha / sigma) ** 2
+        return snr
+
+    return fn
 
 
 def main():
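
For intuition, the helper above evaluates SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]) for each sampled timestep. A minimal usage sketch (assuming diffusers' DDPMScheduler and the compute_snr defined in this diff; not part of the commit):

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
snr_fn = compute_snr(noise_scheduler)

timesteps = torch.tensor([10, 500, 990])
snr = snr_fn(timesteps)  # per-sample SNR: large for small t (little noise), small near t = 999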
@@ -476,6 +510,9 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    if args.snr_gamma is not None:
+        snr_fn = compute_snr(noise_scheduler)
+
     # Initialize the optimizer
     if args.use_8bit_adam:
         try:
@@ -526,7 +563,7 @@ def load_model_hook(models, input_dir):
     column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
-    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
     if args.image_column is None:
         image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
     else:
@@ -734,7 +771,23 @@ def collate_fn(examples):
 
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = snr_fn(timesteps)
+                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                        dim=1
+                    )[0]
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
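
To make the rebalancing concrete, here is a tiny self-contained sketch (toy numbers, not from the commit) of the same min(SNR, gamma) clamp, written with torch.minimum, which is equivalent to the stack/min idiom above:

import torch

gamma = 5.0
snr = torch.tensor([0.2, 3.0, 40.0])            # SNR(t) for a batch of three sampled timesteps
per_sample_mse = torch.tensor([0.9, 1.1, 1.0])  # MSE already averaged over the non-batch dims
mse_loss_weights = torch.minimum(snr, torch.full_like(snr, gamma))  # -> tensor([0.2, 3.0, 5.0])
loss = (per_sample_mse * mse_loss_weights).mean()

Low-noise samples with very high SNR are capped at a weight of gamma, while heavily noised samples keep their small SNR weight; this is the rebalancing the paper reports as speeding up convergence.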
